from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset
import torch
from tqdm import tqdm
import math
import os
# Number CUDA devices by PCI bus ID (the order nvidia-smi lists them) rather
# than CUDA's default fastest-first ordering, so index "0" below is stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only physical GPU 0 to this process. This only takes effect if it is
# set before CUDA is initialized; torch initializes CUDA lazily on first use,
# so setting it after `import torch` (but before any CUDA call) still works.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from modify_gpt2_model import modify_gpt2_model_attention as modify_gpt2

# Checkpoint to evaluate; 'gpt2-medium', 'gpt2-large' or 'gpt2-xl' also work here.
model_name = 'gpt2'
# Forwarded as `cache_dir`; None leaves the download location to transformers.
store_path = None

# Weights are loaded in bfloat16.
model = GPT2LMHeadModel.from_pretrained(
    model_name,
    cache_dir=store_path,
    torch_dtype=torch.bfloat16,
)
tokenizer = GPT2TokenizerFast.from_pretrained(
    model_name,
    cache_dir=store_path,
)
model.eval()

# Prefer the (single visible) GPU, fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# WikiText-2 (raw) validation split; its text is the perplexity benchmark below.
dataset = load_dataset(
    path='wikitext',
    name='wikitext-2-raw-v1',
    split='validation',
    trust_remote_code=True,
)

def concatenate_texts(examples):
    """Collapse a batch of rows into a single row.

    The batch's ``'text'`` entries are joined with one space each.
    Returns a batch dict holding a one-element ``'text'`` list.
    """
    joined = ' '.join(examples['text'])
    return dict(text=[joined])

# Merge the whole split into one string: batch_size=-1 hands every row to
# concatenate_texts in a single call, so the mapped dataset has exactly one row.
concat_text = (
    dataset.map(concatenate_texts, batched=True, batch_size=-1)['text'][0]
)

# Token ids for the full text, as a PyTorch tensor.
input_ids = tokenizer.encode(concat_text, return_tensors='pt')

# Sliding-window evaluation: each forward pass sees at most `max_length`
# tokens and the window advances by `stride` tokens.
max_length, stride = 1024, 512

def validate_model(model):
    """Return the token-level perplexity of ``model`` on the module-level ``input_ids``.

    A window of at most ``max_length`` tokens slides over ``input_ids`` in steps
    of ``stride`` (both module-level). Each window scores only the tokens it has
    not scored before (positions ``i .. end_loc``). Its earlier tokens are context
    only: their labels are set to -100.

    The per-window mean losses are weighted by the number of positions each
    window actually scores. The result is therefore exp(mean NLL over all scored
    tokens): the final, usually shorter, window no longer counts as much as a
    full one.

    Args:
        model: causal LM called as ``model(ids, labels=labels)`` that exposes
            ``.loss``. This assumes Hugging Face semantics: labels are shifted by
            one inside the model, and -100 labels are ignored in a mean
            cross-entropy. The model is switched to eval mode.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``input_ids`` has fewer than two tokens, so nothing can be
            scored.
    """
    model.eval()
    seq_len = input_ids.size(1)
    nll_sum = torch.zeros((), dtype=torch.float64, device=device)
    n_scored = 0
    for i in tqdm(range(0, seq_len, stride)):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # new tokens scored by this window (>= 1)
        input_ids_chunk = input_ids[:, begin_loc:end_loc].to(device)
        chunk_len = input_ids_chunk.size(1)
        target_ids = input_ids_chunk.clone()
        # Context-only prefix: excluded from the loss.
        target_ids[:, : chunk_len - trg_len] = -100

        # The model scores labels[:, 1:] only (it shifts internally), so the
        # number of unmasked, scored positions per row is min(chunk_len - 1, trg_len).
        n_pred = min(chunk_len - 1, trg_len) * input_ids_chunk.size(0)
        if n_pred == 0:
            continue  # a 1-token window has nothing to predict (loss would be NaN)

        with torch.no_grad():
            outputs = model(input_ids_chunk, labels=target_ids)
        # outputs.loss is the mean over n_pred tokens; convert back to a sum.
        nll_sum += outputs.loss.to(torch.float64) * n_pred
        n_scored += n_pred

    if n_scored == 0:
        raise ValueError("input_ids must contain at least 2 tokens to compute perplexity")
    return torch.exp(nll_sum / n_scored).item()
# Baseline: perplexity of the unmodified pretrained model.
ppl = validate_model(model)
print('Original Validation Perplexity: %.5f' % ppl)


# Every transformer block index: 0 .. number_of_blocks - 1.
remove_indices = [*range(len(model.transformer.h))]

# Load a fresh pretrained copy before handing it to modify_gpt2.
# NOTE(review): modify_gpt2 (modify_gpt2_model_attention) lives in another
# module; presumably it removes or bypasses the listed layers, possibly in
# place — confirm there.
model = GPT2LMHeadModel.from_pretrained(
    model_name,
    cache_dir=store_path,
    torch_dtype=torch.bfloat16,
)
model_removed = modify_gpt2(model, remove_indices)
model_removed.to(device)

ppl = validate_model(model_removed)
print('remove All Layer: Validation Perplexity: {:.5f}'.format(ppl))